{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "ename": "ImportError",
     "evalue": "cannot import name 'array' from 'numpy.core' (unknown location)",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mImportError\u001b[0m                               Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[1], line 2\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnp\u001b[39;00m\n\u001b[1;32m----> 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpyplot\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mplt\u001b[39;00m\n\u001b[0;32m      3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msklearn\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpreprocessing\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m StandardScaler\n\u001b[0;32m      4\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msklearn\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m svm\n",
      "File \u001b[1;32md:\\anaconda3\\envs\\gpucc\\lib\\site-packages\\matplotlib\\__init__.py:159\u001b[0m\n\u001b[0;32m    155\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mpackaging\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mversion\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m parse \u001b[38;5;28;01mas\u001b[39;00m parse_version\n\u001b[0;32m    157\u001b[0m \u001b[38;5;66;03m# cbook must import matplotlib only within function\u001b[39;00m\n\u001b[0;32m    158\u001b[0m \u001b[38;5;66;03m# definitions, so it is safe to import from it here.\u001b[39;00m\n\u001b[1;32m--> 159\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m _api, _version, cbook, _docstring, rcsetup\n\u001b[0;32m    160\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcbook\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m sanitize_sequence\n\u001b[0;32m    161\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_api\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m MatplotlibDeprecationWarning\n",
      "File \u001b[1;32md:\\anaconda3\\envs\\gpucc\\lib\\site-packages\\matplotlib\\rcsetup.py:28\u001b[0m\n\u001b[0;32m     26\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbackends\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m BackendFilter, backend_registry\n\u001b[0;32m     27\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcbook\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m ls_mapper\n\u001b[1;32m---> 28\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcolors\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Colormap, is_color_like\n\u001b[0;32m     29\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_fontconfig_pattern\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m parse_fontconfig_pattern\n\u001b[0;32m     30\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_enums\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m JoinStyle, CapStyle\n",
      "File \u001b[1;32md:\\anaconda3\\envs\\gpucc\\lib\\site-packages\\matplotlib\\colors.py:57\u001b[0m\n\u001b[0;32m     55\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mmpl\u001b[39;00m\n\u001b[0;32m     56\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnp\u001b[39;00m\n\u001b[1;32m---> 57\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m _api, _cm, cbook, scale\n\u001b[0;32m     58\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_color_data\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m BASE_COLORS, TABLEAU_COLORS, CSS4_COLORS, XKCD_COLORS\n\u001b[0;32m     61\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01m_ColorMapping\u001b[39;00m(\u001b[38;5;28mdict\u001b[39m):\n",
      "File \u001b[1;32md:\\anaconda3\\envs\\gpucc\\lib\\site-packages\\matplotlib\\scale.py:22\u001b[0m\n\u001b[0;32m     20\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mmpl\u001b[39;00m\n\u001b[0;32m     21\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m _api, _docstring\n\u001b[1;32m---> 22\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mticker\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (\n\u001b[0;32m     23\u001b[0m     NullFormatter, ScalarFormatter, LogFormatterSciNotation, LogitFormatter,\n\u001b[0;32m     24\u001b[0m     NullLocator, LogLocator, AutoLocator, AutoMinorLocator,\n\u001b[0;32m     25\u001b[0m     SymmetricalLogLocator, AsinhLocator, LogitLocator)\n\u001b[0;32m     26\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtransforms\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Transform, IdentityTransform\n\u001b[0;32m     29\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mScaleBase\u001b[39;00m:\n",
      "File \u001b[1;32md:\\anaconda3\\envs\\gpucc\\lib\\site-packages\\matplotlib\\ticker.py:144\u001b[0m\n\u001b[0;32m    142\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mmpl\u001b[39;00m\n\u001b[0;32m    143\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m _api, cbook\n\u001b[1;32m--> 144\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m transforms \u001b[38;5;28;01mas\u001b[39;00m mtransforms\n\u001b[0;32m    146\u001b[0m _log \u001b[38;5;241m=\u001b[39m logging\u001b[38;5;241m.\u001b[39mgetLogger(\u001b[38;5;18m__name__\u001b[39m)\n\u001b[0;32m    148\u001b[0m __all__ \u001b[38;5;241m=\u001b[39m (\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mTickHelper\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mFormatter\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mFixedFormatter\u001b[39m\u001b[38;5;124m'\u001b[39m,\n\u001b[0;32m    149\u001b[0m            \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mNullFormatter\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mFuncFormatter\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mFormatStrFormatter\u001b[39m\u001b[38;5;124m'\u001b[39m,\n\u001b[0;32m    150\u001b[0m            \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mStrMethodFormatter\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mScalarFormatter\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mLogFormatter\u001b[39m\u001b[38;5;124m'\u001b[39m,\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m    156\u001b[0m            \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mMultipleLocator\u001b[39m\u001b[38;5;124m'\u001b[39m, 
\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mMaxNLocator\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mAutoMinorLocator\u001b[39m\u001b[38;5;124m'\u001b[39m,\n\u001b[0;32m    157\u001b[0m            \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mSymmetricalLogLocator\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mAsinhLocator\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mLogitLocator\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
      "File \u001b[1;32md:\\anaconda3\\envs\\gpucc\\lib\\site-packages\\matplotlib\\transforms.py:46\u001b[0m\n\u001b[0;32m     43\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmath\u001b[39;00m\n\u001b[0;32m     45\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnp\u001b[39;00m\n\u001b[1;32m---> 46\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mlinalg\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m inv\n\u001b[0;32m     48\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m _api\n\u001b[0;32m     49\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_path\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (\n\u001b[0;32m     50\u001b[0m     affine_transform, count_bboxes_overlapping_bbox, update_path_extents)\n",
      "File \u001b[1;32md:\\anaconda3\\envs\\gpucc\\lib\\site-packages\\numpy\\linalg\\__init__.py:73\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m      2\u001b[0m \u001b[38;5;124;03m``numpy.linalg``\u001b[39;00m\n\u001b[0;32m      3\u001b[0m \u001b[38;5;124;03m================\u001b[39;00m\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m     70\u001b[0m \n\u001b[0;32m     71\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m     72\u001b[0m \u001b[38;5;66;03m# To get sub-modules\u001b[39;00m\n\u001b[1;32m---> 73\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m linalg\n\u001b[0;32m     74\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mlinalg\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;241m*\u001b[39m\n\u001b[0;32m     76\u001b[0m __all__ \u001b[38;5;241m=\u001b[39m linalg\u001b[38;5;241m.\u001b[39m__all__\u001b[38;5;241m.\u001b[39mcopy()\n",
      "File \u001b[1;32md:\\anaconda3\\envs\\gpucc\\lib\\site-packages\\numpy\\linalg\\linalg.py:21\u001b[0m\n\u001b[0;32m     18\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01moperator\u001b[39;00m\n\u001b[0;32m     19\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mwarnings\u001b[39;00m\n\u001b[1;32m---> 21\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcore\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (\n\u001b[0;32m     22\u001b[0m     array, asarray, zeros, empty, empty_like, intc, single, double,\n\u001b[0;32m     23\u001b[0m     csingle, cdouble, inexact, complexfloating, newaxis, \u001b[38;5;28mall\u001b[39m, Inf, dot,\n\u001b[0;32m     24\u001b[0m     add, multiply, sqrt, \u001b[38;5;28msum\u001b[39m, isfinite,\n\u001b[0;32m     25\u001b[0m     finfo, errstate, geterrobj, moveaxis, amin, amax, product, \u001b[38;5;28mabs\u001b[39m,\n\u001b[0;32m     26\u001b[0m     atleast_2d, intp, asanyarray, object_, matmul,\n\u001b[0;32m     27\u001b[0m     swapaxes, divide, count_nonzero, isnan, sign, argsort, sort,\n\u001b[0;32m     28\u001b[0m     reciprocal\n\u001b[0;32m     29\u001b[0m )\n\u001b[0;32m     30\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcore\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmultiarray\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m normalize_axis_index\n\u001b[0;32m     31\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcore\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01moverrides\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m set_module\n",
      "\u001b[1;31mImportError\u001b[0m: cannot import name 'array' from 'numpy.core' (unknown location)"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn import svm\n",
    "from scipy.stats import multivariate_normal\n",
    "import matplotlib.patches as patches\n",
    "\n",
    "# Seed NumPy's global RNG so every run draws the same samples\n",
    "np.random.seed(42)\n",
    "\n",
    "# Inliers: 200 draws from a correlated 2-D Gaussian centred at (2, 2)\n",
    "X_normal = np.random.multivariate_normal(\n",
    "    mean=[2, 2],\n",
    "    cov=[[1, 0.5], [0.5, 1]],\n",
    "    size=200,\n",
    ")\n",
    "# Outliers: 20 points drawn uniformly from the square [-2, 6) x [-2, 6)\n",
    "X_outliers = np.random.uniform(low=-2, high=6, size=(20, 2))\n",
    "# Inliers followed by outliers, stacked row-wise into one array\n",
    "X_combined = np.vstack((X_normal, X_outliers))\n",
    "\n",
    "# Evaluation grid: 500 x 500 points spanning the plotting window\n",
    "x_min, x_max = -3, 7\n",
    "y_min, y_max = -3, 7\n",
    "xx, yy = np.meshgrid(\n",
    "    np.linspace(x_min, x_max, 500),\n",
    "    np.linspace(y_min, y_max, 500),\n",
    ")\n",
    "# One (x, y) row per grid node\n",
    "grid_points = np.column_stack((xx.ravel(), yy.ravel()))\n",
    "\n",
    "########################################\n",
    "# Deep Bayesian Network simulation\n",
    "########################################\n",
    "# Sample mean and sample covariance of the inliers\n",
    "mean_est = X_normal.mean(axis=0)\n",
    "cov_est = np.cov(X_normal, rowvar=False)\n",
    "\n",
    "# Monte-Carlo stand-in for Bayesian uncertainty: the covariance estimate is\n",
    "# jittered a few times and the resulting Gaussian log-densities are averaged\n",
    "def bayesian_log_prob(x, y, n_samples=10):\n",
    "    \"\"\"Mean Gaussian log-density over randomly perturbed covariance matrices.\n",
    "\n",
    "    The coordinates in ``x`` and ``y`` are flattened and paired into 2-D\n",
    "    points. In each of ``n_samples`` rounds a random positive semi-definite\n",
    "    perturbation is added to the module-level ``cov_est`` and the log-density\n",
    "    of a Gaussian centred at the module-level ``mean_est`` is evaluated; the\n",
    "    per-round log-densities are then averaged.\n",
    "\n",
    "    Draws from NumPy's global random state, so the result depends on the\n",
    "    seed and on how many random numbers were consumed before the call.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x, y : numpy.ndarray\n",
    "        Coordinate arrays with the same number of elements.\n",
    "    n_samples : int, default 10\n",
    "        Number of perturbed covariance matrices to average over.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray\n",
    "        Averaged log-densities, reshaped to ``x.shape``.\n",
    "    \"\"\"\n",
    "    coords = np.column_stack([x.flatten(), y.flatten()])\n",
    "    total_log_prob = np.zeros(len(coords))\n",
    "\n",
    "    for _ in range(n_samples):\n",
    "        # noise @ noise.T is symmetric positive semi-definite by construction\n",
    "        noise = np.random.normal(0, 0.1, size=cov_est.shape)\n",
    "        cov_sample = cov_est + 0.1 * (noise @ noise.T)\n",
    "\n",
    "        # Safety net: shift the spectrum up if any eigenvalue came out negative\n",
    "        smallest_eig = np.min(np.real(np.linalg.eigvals(cov_sample)))\n",
    "        if smallest_eig < 0:\n",
    "            cov_sample = cov_sample - 1.1 * smallest_eig * np.eye(*cov_sample.shape)\n",
    "\n",
    "        # Log-density of every point under this covariance sample\n",
    "        sampled_gaussian = multivariate_normal(mean=mean_est, cov=cov_sample)\n",
    "        total_log_prob += sampled_gaussian.logpdf(coords)\n",
    "\n",
    "    # Arithmetic mean of the per-sample log-densities, back in the input's shape\n",
    "    return (total_log_prob / n_samples).reshape(x.shape)\n",
    "\n",
    "# bayesian_log_prob draws its covariance perturbations from NumPy's global\n",
    "# RNG. Snapshot the state so the grid surface and the threshold below are\n",
    "# evaluated with the same perturbation samples; otherwise the threshold\n",
    "# would be a quantile of a differently perturbed model than the one plotted.\n",
    "posterior_rng_state = np.random.get_state()\n",
    "\n",
    "# Averaged log-density on the evaluation grid\n",
    "Z_bayes = bayesian_log_prob(xx, yy)\n",
    "\n",
    "# Averaged log-density at the inliers, replaying the same random draws\n",
    "np.random.set_state(posterior_rng_state)\n",
    "normal_log_probs = bayesian_log_prob(X_normal[:, 0].reshape(-1, 1),\n",
    "                                     X_normal[:, 1].reshape(-1, 1))\n",
    "# 5th percentile of the inlier log-densities (low log-density = unlikely under the model)\n",
    "bayes_threshold = np.percentile(normal_log_probs, 5)\n",
    "\n",
    "########################################\n",
    "# Attention-Guided Cleaning methodology\n",
    "########################################\n",
    "# Centre of the inliers; attention weights are based on distance from it\n",
    "mean_normal = np.mean(X_normal, axis=0)\n",
    "\n",
    "# Function to compute attention weights\n",
    "def compute_attention(points, mean_center, temperature=0.5):\n",
    "    \"\"\"Distance-based attention weight for each point.\n",
    "\n",
    "    Each weight is ``exp(-||p - mean_center||**2 / temperature)``. It lies in\n",
    "    [0, 1] and equals 1 exactly at ``mean_center``. The weight of a point\n",
    "    depends only on that point, never on the other rows of ``points``, so\n",
    "    weights from separate calls (grid, data set, single point) are directly\n",
    "    comparable.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    points : array-like, shape (n, d) or (d,)\n",
    "        Points to score; a single 1-D point is treated as one row.\n",
    "    mean_center : array-like, shape (d,)\n",
    "        Centre from which distances are measured.\n",
    "    temperature : float, default 0.5\n",
    "        Positive scale; lower values make the attention more focused.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray, shape (n,)\n",
    "        Attention weight of every point.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``temperature`` is not positive.\n",
    "    \"\"\"\n",
    "    if temperature <= 0:\n",
    "        raise ValueError(\"temperature must be positive\")\n",
    "\n",
    "    pts = np.atleast_2d(np.asarray(points, dtype=float))\n",
    "\n",
    "    # Squared Euclidean distance of every row from the centre\n",
    "    dist_sq = np.sum((pts - mean_center) ** 2, axis=1)\n",
    "\n",
    "    # exp(-d^2 / T) is already in [0, 1]. Do not rescale by the batch maximum:\n",
    "    # that would turn the weight of any single point into 1.0 (or NaN when the\n",
    "    # exponential underflows to 0) and make weights from different calls\n",
    "    # incomparable.\n",
    "    return np.exp(-dist_sq / temperature)\n",
    "\n",
    "# Cut-off on the attention weight used for anomaly detection\n",
    "# (higher attention = more likely to be considered normal)\n",
    "attention_threshold = 0.3\n",
    "\n",
    "# Attention weight at every grid node ...\n",
    "grid_attention = compute_attention(grid_points, mean_normal)\n",
    "\n",
    "# ... folded back into the 2-D shape of the mesh for visualization\n",
    "Z_attention = np.reshape(grid_attention, xx.shape)\n",
    "\n",
    "# Decision function: a covariance-aware distance damped by attention,\n",
    "# simulating how attention guides the cleaning process\n",
    "def attention_guided_decision(points, mean_center, cov_matrix, attention_weights):\n",
    "    \"\"\"Anomaly score: squared Mahalanobis distance scaled by (1 - attention).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    points : numpy.ndarray, shape (n, d)\n",
    "        Points to score, one per row.\n",
    "    mean_center : numpy.ndarray, shape (d,)\n",
    "        Centre subtracted from every point.\n",
    "    cov_matrix : numpy.ndarray, shape (d, d)\n",
    "        Covariance matrix; it is inverted, so it must be non-singular.\n",
    "    attention_weights : numpy.ndarray, shape (n,)\n",
    "        Attention weight of every point; higher attention lowers the score.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray, shape (n,)\n",
    "        Decision value of every point.\n",
    "    \"\"\"\n",
    "    centered = points - mean_center\n",
    "    precision = np.linalg.inv(cov_matrix)\n",
    "\n",
    "    # Row-wise quadratic form (p - mu)^T Sigma^-1 (p - mu)\n",
    "    quad_form = (centered @ precision * centered).sum(axis=1)\n",
    "\n",
    "    # Attention close to 1 suppresses the score, attention close to 0 keeps it\n",
    "    return quad_form * (1 - attention_weights)\n",
    "\n",
    "# Sample covariance of the inliers\n",
    "cov_normal = np.cov(X_normal, rowvar=False)\n",
    "\n",
    "# Decision value at every grid node, then folded back into the mesh shape\n",
    "grid_decision = attention_guided_decision(\n",
    "    grid_points, mean_normal, cov_normal, grid_attention\n",
    ")\n",
    "Z_decision = np.reshape(grid_decision, xx.shape)\n",
    "\n",
    "# Threshold: 95th percentile of the decision values of the inliers\n",
    "normal_attention = compute_attention(X_normal, mean_normal)\n",
    "normal_decision = attention_guided_decision(\n",
    "    X_normal, mean_normal, cov_normal, normal_attention\n",
    ")\n",
    "decision_threshold = np.percentile(normal_decision, 95)\n",
    "\n",
    "########################################\n",
    "# Plotting side by side comparison\n",
    "########################################\n",
    "fig, axes = plt.subplots(1, 2, figsize=(14, 6))\n",
    "\n",
    "# Plot 1: Deep Bayesian Network\n",
    "ax = axes[0]\n",
    "# Graded shading for log-densities between the grid minimum and the threshold\n",
    "levels_bayes = np.linspace(Z_bayes.min(), bayes_threshold, 7)\n",
    "ax.contourf(xx, yy, Z_bayes, levels=levels_bayes, cmap='Greys', alpha=0.8)\n",
    "# Decision boundary: the contour where the log-density equals the threshold\n",
    "ax.contour(xx, yy, Z_bayes, levels=[bayes_threshold], linewidths=2, colors='black')\n",
    "# Flat fill for log-densities between the threshold and the grid maximum\n",
    "ax.contourf(xx, yy, Z_bayes, levels=[bayes_threshold, Z_bayes.max()], colors=['#d3d3d3'], alpha=0.8)\n",
    "ax.scatter(X_normal[:, 0], X_normal[:, 1], facecolors='none', edgecolors='black', s=30, label='Normal Data')\n",
    "ax.scatter(X_outliers[:, 0], X_outliers[:, 1], c='black', s=50, marker='x', label='Outliers')\n",
    "ax.set_title('Deep Bayesian Network', fontsize=14)\n",
    "ax.set_xlabel('Feature 1')\n",
    "ax.set_ylabel('Feature 2')\n",
    "ax.legend(loc='upper right')\n",
    "ax.set_xlim([x_min, x_max])\n",
    "ax.set_ylim([y_min, y_max])\n",
    "\n",
    "# Add mathematical formulation below the axes (axes-fraction coordinates).\n",
    "# Raw string so that mathtext receives \\theta, \\int, ... verbatim.\n",
    "ax.text(0.5, -0.15, r\"$p(\\mathbf{x}|\\mathcal{D}) = \\int p(\\mathbf{x}|\\theta)p(\\theta|\\mathcal{D})d\\theta$\",\n",
    "        transform=ax.transAxes, ha='center', va='center', fontsize=12,\n",
    "        bbox=dict(facecolor='white', alpha=0.8, boxstyle='round,pad=0.5'))\n",
    "\n",
    "# Plot 2: Attention-Guided Cleaning\n",
    "ax = axes[1]\n",
    "# Graded shading of the decision value over its full range on the grid\n",
    "contour_levels = np.linspace(Z_decision.min(), Z_decision.max(), 11)\n",
    "cf = ax.contourf(xx, yy, Z_decision, levels=contour_levels, cmap='Greys', alpha=0.8)\n",
    "# Decision boundary: the contour where the decision value equals the threshold\n",
    "ax.contour(xx, yy, Z_decision, levels=[decision_threshold], linewidths=2, colors='black')\n",
    "# Flat fill for decision values between 0 and the threshold ('Normal Region' in the legend)\n",
    "ax.contourf(xx, yy, Z_decision, levels=[0, decision_threshold], colors=['#d3d3d3'], alpha=0.8)\n",
    "\n",
    "# Plot normal data\n",
    "ax.scatter(X_normal[:, 0], X_normal[:, 1], facecolors='none', edgecolors='black', s=30, label='Normal Data')\n",
    "\n",
    "# Plot each outlier as either \"cleaned\" or \"flagged\", depending on its attention\n",
    "for point in X_outliers:\n",
    "    # Attention weight of this outlier alone\n",
    "    att = compute_attention(point.reshape(1, -1), mean_normal)[0]\n",
    "\n",
    "    if att > attention_threshold:\n",
    "        # High attention: shown as \"cleaned\", with a short arrow pointing from\n",
    "        # the point toward the mean of the normal data\n",
    "        direction = mean_normal - point\n",
    "        norm = np.linalg.norm(direction)\n",
    "\n",
    "        # A point sitting exactly on the mean has no direction to draw\n",
    "        if norm > 0:\n",
    "            direction = direction / norm\n",
    "            ax.arrow(point[0], point[1], direction[0] * 0.5, direction[1] * 0.5,\n",
    "                     head_width=0.15, head_length=0.3, fc='blue', ec='blue', alpha=0.7)\n",
    "\n",
    "        # Show original point as a blue circle\n",
    "        ax.scatter(point[0], point[1], c='blue', s=50, marker='o', alpha=0.7)\n",
    "    else:\n",
    "        # Low attention: flagged as anomaly - show with X\n",
    "        ax.scatter(point[0], point[1], c='red', s=50, marker='x')\n",
    "\n",
    "# Add legend built from proxy artists, since the points above are drawn one by one\n",
    "normal_patch = patches.Patch(color='#d3d3d3', label='Normal Region')\n",
    "cleaned_point = plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='blue', markersize=10, label='Cleaned Point')\n",
    "# linestyle='None' keeps the handle a bare marker, matching the scatter points\n",
    "anomaly_point = plt.Line2D([0], [0], marker='x', color='red', linestyle='None', markersize=10, label='Flagged Anomaly')\n",
    "cleaning_arrow = plt.Line2D([0], [0], color='blue', lw=2, label='Cleaning Direction')\n",
    "ax.legend(handles=[normal_patch, cleaned_point, anomaly_point, cleaning_arrow], loc='upper right')\n",
    "\n",
    "ax.set_title('Attention-Guided Cleaning', fontsize=14)\n",
    "ax.set_xlabel('Feature 1')\n",
    "ax.set_ylabel('Feature 2')\n",
    "ax.set_xlim([x_min, x_max])\n",
    "ax.set_ylim([y_min, y_max])\n",
    "\n",
    "# Add mathematical formulation below the axes (axes-fraction coordinates)\n",
    "ax.text(0.5, -0.15, r\"$f(\\mathbf{x}) = d_M(\\mathbf{x}) \\cdot (1 - a(\\mathbf{x}))$\",\n",
    "        transform=ax.transAxes, ha='center', va='center', fontsize=12,\n",
    "        bbox=dict(facecolor='white', alpha=0.8, boxstyle='round,pad=0.5'))\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('dbn_vs_agc.png', dpi=300, bbox_inches='tight')\n",
    "plt.show()\n",
    "\n",
    "print('Side-by-side comparison of Deep Bayesian Network and Attention-Guided Cleaning created.')\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gpucc",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
